{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qwBHHt-XvPqn"
      },
      "source": [
        "# Plot MoViNet Video Stream Predictions\n",
        "\n",
        "This notebook uses [MoViNets (Mobile Video Networks)](https://github.com/tensorflow/models/tree/master/official/projects/movinet) to predict a human action in a streaming video and outputs a visualization of predictions on each frame.\n",
        "\n",
        "Provide a video URL or upload your own to see how predictions change over time. All models can be run on CPU.\n",
        "\n",
        "Pretrained models are provided by [TensorFlow Hub](https://tfhub.dev/google/collections/movinet/) and the [TensorFlow Model Garden](https://github.com/tensorflow/models/tree/master/official/projects/movinet), trained on [Kinetics 600](https://deepmind.com/research/open-source/kinetics) for video action classification. All Models use TensorFlow 2 with Keras for inference and training. See the [research paper](https://arxiv.org/pdf/2103.11511.pdf) for more details.\n",
        "\n",
        "Example output using [this gif](https://github.com/tensorflow/models/raw/f8af2291cced43fc9f1d9b41ddbf772ae7b0d7d2/official/projects/movinet/files/jumpingjack.gif) as input:\n",
        "\n",
        "![jumping jacks plot](https://storage.googleapis.com/tf_model_garden/vision/movinet/artifacts/jumpingjacks_plot.gif)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "ElvELd9mIfZe"
      },
      "outputs": [],
      "source": [
        "#@title Run this cell to initialize and setup a [MoViNet](https://github.com/tensorflow/models/tree/master/official/projects/movinet) model.\n",
        "\n",
        "\n",
        "# Install the mediapy package for visualizing images/videos.\n",
        "# See https://github.com/google/mediapy\n",
        "!pip install -q mediapy\n",
        "\n",
        "# Run imports\n",
        "import os\n",
        "import io\n",
        "\n",
        "import matplotlib as mpl\n",
        "import matplotlib.pyplot as plt\n",
        "import mediapy as media\n",
        "import numpy as np\n",
        "import PIL\n",
        "import pandas as pd\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds\n",
        "import tensorflow_hub as hub\n",
        "import tqdm\n",
        "from google.colab import files\n",
        "import urllib.request\n",
        "\n",
        "mpl.rcParams.update({\n",
        "    'font.size': 10,\n",
        "})\n",
        "\n",
        "\n",
        "# Download Kinetics 600 label map\n",
        "!wget https://raw.githubusercontent.com/tensorflow/models/f8af2291cced43fc9f1d9b41ddbf772ae7b0d7d2/official/projects/movinet/files/kinetics_600_labels.txt -O labels.txt -q\n",
        "\n",
        "with tf.io.gfile.GFile('labels.txt') as f:\n",
        "  lines = f.readlines()\n",
        "  KINETICS_600_LABELS_LIST = [line.strip() for line in lines]\n",
        "  KINETICS_600_LABELS = tf.constant(KINETICS_600_LABELS_LIST)\n",
        "\n",
        "def get_top_k(probs, k=5, label_map=KINETICS_600_LABELS):\n",
        "  \"\"\"Outputs the top k model labels and probabilities on the given video.\"\"\"\n",
        "  top_predictions = tf.argsort(probs, axis=-1, direction='DESCENDING')[:k]\n",
        "  top_labels = tf.gather(label_map, top_predictions, axis=-1)\n",
        "  top_labels = [label.decode('utf8') for label in top_labels.numpy()]\n",
        "  top_probs = tf.gather(probs, top_predictions, axis=-1).numpy()\n",
        "  return tuple(zip(top_labels, top_probs))\n",
        "\n",
        "def predict_top_k(model, video, k=5, label_map=KINETICS_600_LABELS):\n",
        "  \"\"\"Outputs the top k model labels and probabilities on the given video.\"\"\"\n",
        "  outputs = model.predict(video[tf.newaxis])[0]\n",
        "  probs = tf.nn.softmax(outputs)\n",
        "  return get_top_k(probs, k=k, label_map=label_map)\n",
        "\n",
        "def load_movinet_from_hub(model_id, model_mode, hub_version=3):\n",
        "  \"\"\"Loads a MoViNet model from TF Hub.\"\"\"\n",
        "  hub_url = f'https://tfhub.dev/tensorflow/movinet/{model_id}/{model_mode}/kinetics-600/classification/{hub_version}'\n",
        "\n",
        "  encoder = hub.KerasLayer(hub_url, trainable=True)\n",
        "\n",
        "  inputs = tf.keras.layers.Input(\n",
        "      shape=[None, None, None, 3],\n",
        "      dtype=tf.float32)\n",
        "\n",
        "  if model_mode == 'base':\n",
        "    inputs = dict(image=inputs)\n",
        "  else:\n",
        "    # Define the state inputs, which is a dict that maps state names to tensors.\n",
        "    init_states_fn = encoder.resolved_object.signatures['init_states']\n",
        "    state_shapes = {\n",
        "        name: ([s if s \u003e 0 else None for s in state.shape], state.dtype)\n",
        "        for name, state in init_states_fn(tf.constant([0, 0, 0, 0, 3])).items()\n",
        "    }\n",
        "    states_input = {\n",
        "        name: tf.keras.Input(shape[1:], dtype=dtype, name=name)\n",
        "        for name, (shape, dtype) in state_shapes.items()\n",
        "    }\n",
        "\n",
        "    # The inputs to the model are the states and the video\n",
        "    inputs = {**states_input, 'image': inputs}\n",
        "\n",
        "  # Output shape: [batch_size, 600]\n",
        "  outputs = encoder(inputs)\n",
        "\n",
        "  model = tf.keras.Model(inputs, outputs)\n",
        "  model.build([1, 1, 1, 1, 3])\n",
        "\n",
        "  return model\n",
        "\n",
        "# Download example gif\n",
        "!wget https://github.com/tensorflow/models/raw/f8af2291cced43fc9f1d9b41ddbf772ae7b0d7d2/official/projects/movinet/files/jumpingjack.gif -O jumpingjack.gif -q\n",
        "\n",
        "def load_gif(file_path, image_size=(224, 224)):\n",
        "  \"\"\"Loads a gif file into a TF tensor.\"\"\"\n",
        "  with tf.io.gfile.GFile(file_path, 'rb') as f:\n",
        "    video = tf.io.decode_gif(f.read())\n",
        "  video = tf.image.resize(video, image_size)\n",
        "  video = tf.cast(video, tf.float32) / 255.\n",
        "  return video\n",
        "\n",
        "def get_top_k_streaming_labels(probs, k=5, label_map=KINETICS_600_LABELS_LIST):\n",
        "  \"\"\"Returns the top-k labels over an entire video sequence.\n",
        "\n",
        "  Args:\n",
        "    probs: probability tensor of shape (num_frames, num_classes) that represents\n",
        "      the probability of each class on each frame.\n",
        "    k: the number of top predictions to select.\n",
        "    label_map: a list of labels to map logit indices to label strings.\n",
        "\n",
        "  Returns:\n",
        "    a tuple of the top-k probabilities, labels, and logit indices\n",
        "  \"\"\"\n",
        "  top_categories_last = tf.argsort(probs, -1, 'DESCENDING')[-1, :1]\n",
        "  categories = tf.argsort(probs, -1, 'DESCENDING')[:, :k]\n",
        "  categories = tf.reshape(categories, [-1])\n",
        "\n",
        "  counts = sorted([\n",
        "      (i.numpy(), tf.reduce_sum(tf.cast(categories == i, tf.int32)).numpy())\n",
        "      for i in tf.unique(categories)[0]\n",
        "  ], key=lambda x: x[1], reverse=True)\n",
        "\n",
        "  top_probs_idx = tf.constant([i for i, _ in counts[:k]])\n",
        "  top_probs_idx = tf.concat([top_categories_last, top_probs_idx], 0)\n",
        "  top_probs_idx = tf.unique(top_probs_idx)[0][:k+1]\n",
        "\n",
        "  top_probs = tf.gather(probs, top_probs_idx, axis=-1)\n",
        "  top_probs = tf.transpose(top_probs, perm=(1, 0))\n",
        "  top_labels = tf.gather(label_map, top_probs_idx, axis=0)\n",
        "  top_labels = [label.decode('utf8') for label in top_labels.numpy()]\n",
        "\n",
        "  return top_probs, top_labels, top_probs_idx\n",
        "\n",
        "def plot_streaming_top_preds_at_step(\n",
        "    top_probs,\n",
        "    top_labels,\n",
        "    step=None,\n",
        "    image=None,\n",
        "    legend_loc='lower left',\n",
        "    duration_seconds=10,\n",
        "    figure_height=500,\n",
        "    playhead_scale=0.8,\n",
        "    grid_alpha=0.3):\n",
        "  \"\"\"Generates a plot of the top video model predictions at a given time step.\n",
        "\n",
        "  Args:\n",
        "    top_probs: a tensor of shape (k, num_frames) representing the top-k\n",
        "      probabilities over all frames.\n",
        "    top_labels: a list of length k that represents the top-k label strings.\n",
        "    step: the current time step in the range [0, num_frames].\n",
        "    image: the image frame to display at the current time step.\n",
        "    legend_loc: the placement location of the legend.\n",
        "    duration_seconds: the total duration of the video.\n",
        "    figure_height: the output figure height.\n",
        "    playhead_scale: scale value for the playhead.\n",
        "    grid_alpha: alpha value for the gridlines.\n",
        "\n",
        "  Returns:\n",
        "    A tuple of the output numpy image, figure, and axes.\n",
        "  \"\"\"\n",
        "  num_labels, num_frames = top_probs.shape\n",
        "  if step is None:\n",
        "    step = num_frames\n",
        "\n",
        "  fig = plt.figure(figsize=(6.5, 7), dpi=300)\n",
        "  gs = mpl.gridspec.GridSpec(8, 1)\n",
        "  ax2 = plt.subplot(gs[:-3, :])\n",
        "  ax = plt.subplot(gs[-3:, :])\n",
        "\n",
        "  if image is not None:\n",
        "    ax2.imshow(image, interpolation='nearest')\n",
        "    ax2.axis('off')\n",
        "\n",
        "  preview_line_x = tf.linspace(0., duration_seconds, num_frames)\n",
        "  preview_line_y = top_probs\n",
        "\n",
        "  line_x = preview_line_x[:step+1]\n",
        "  line_y = preview_line_y[:, :step+1]\n",
        "\n",
        "  for i in range(num_labels):\n",
        "    ax.plot(preview_line_x, preview_line_y[i], label=None, linewidth='1.5',\n",
        "            linestyle=':', color='gray')\n",
        "    ax.plot(line_x, line_y[i], label=top_labels[i], linewidth='2.0')\n",
        "\n",
        "\n",
        "  ax.grid(which='major', linestyle=':', linewidth='1.0', alpha=grid_alpha)\n",
        "  ax.grid(which='minor', linestyle=':', linewidth='0.5', alpha=grid_alpha)\n",
        "\n",
        "  min_height = tf.reduce_min(top_probs) * playhead_scale\n",
        "  max_height = tf.reduce_max(top_probs)\n",
        "  ax.vlines(preview_line_x[step], min_height, max_height, colors='red')\n",
        "  ax.scatter(preview_line_x[step], max_height, color='red')\n",
        "\n",
        "  ax.legend(loc=legend_loc)\n",
        "\n",
        "  plt.xlim(0, duration_seconds)\n",
        "  plt.ylabel('Probability')\n",
        "  plt.xlabel('Time (s)')\n",
        "  plt.yscale('log')\n",
        "\n",
        "  fig.tight_layout()\n",
        "  fig.canvas.draw()\n",
        "\n",
        "  data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n",
        "  data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))\n",
        "  plt.close()\n",
        "\n",
        "  figure_width = int(figure_height * data.shape[1] / data.shape[0])\n",
        "  image = PIL.Image.fromarray(data).resize([figure_width, figure_height])\n",
        "  image = np.array(image)\n",
        "\n",
        "  return image, (fig, ax, ax2)\n",
        "\n",
        "def plot_streaming_top_preds(\n",
        "    probs,\n",
        "    video,\n",
        "    top_k=5,\n",
        "    video_fps=25.,\n",
        "    figure_height=500,\n",
        "    use_progbar=True):\n",
        "  \"\"\"Generates a video plot of the top video model predictions.\n",
        "\n",
        "  Args:\n",
        "    probs: probability tensor of shape (num_frames, num_classes) that represents\n",
        "      the probability of each class on each frame.\n",
        "    video: the video to display in the plot.\n",
        "    top_k: the number of top predictions to select.\n",
        "    video_fps: the input video fps.\n",
        "    figure_fps: the output video fps.\n",
        "    figure_height: the height of the output video.\n",
        "    use_progbar: display a progress bar.\n",
        "\n",
        "  Returns:\n",
        "    A numpy array representing the output video.\n",
        "  \"\"\"\n",
        "  video_fps = 8.\n",
        "  figure_height = 500\n",
        "  steps = video.shape[0]\n",
        "  duration = steps / video_fps\n",
        "\n",
        "  top_probs, top_labels, _ = get_top_k_streaming_labels(probs, k=top_k)\n",
        "\n",
        "  images = []\n",
        "  step_generator = tqdm.trange(steps) if use_progbar else range(steps)\n",
        "  for i in step_generator:\n",
        "    image, _ = plot_streaming_top_preds_at_step(\n",
        "        top_probs=top_probs,\n",
        "        top_labels=top_labels,\n",
        "        step=i,\n",
        "        image=video[i],\n",
        "        duration_seconds=duration,\n",
        "        figure_height=figure_height,\n",
        "    )\n",
        "    images.append(image)\n",
        "\n",
        "  return np.array(images)\n",
        "\n",
        "def generate_plot(\n",
        "    model,\n",
        "    video_url=None,\n",
        "    resolution=224,\n",
        "    video_fps=25,\n",
        "    display_fps=25):\n",
        "  # Load the video\n",
        "  if not video_url:\n",
        "    video_bytes = list(files.upload().values())[0]\n",
        "    with open('video', 'wb') as f:\n",
        "      f.write(video_bytes)\n",
        "  else:\n",
        "    urllib.request.urlretrieve(video_url, \"video\")\n",
        "\n",
        "  video = tf.cast(media.read_video('video'), tf.float32) / 255.\n",
        "  video = tf.image.resize(video, [resolution, resolution], preserve_aspect_ratio=True)\n",
        "\n",
        "  # Create initial states for the stream model\n",
        "  init_states_fn = model.layers[-1].resolved_object.signatures['init_states']\n",
        "  init_states = init_states_fn(tf.shape(video[tf.newaxis]))\n",
        "\n",
        "  clips = tf.split(video[tf.newaxis], video.shape[0], axis=1)\n",
        "\n",
        "  all_logits = []\n",
        "\n",
        "  print('Running the model on the video...')\n",
        "\n",
        "  # To run on a video, pass in one frame at a time\n",
        "  states = init_states\n",
        "  for clip in tqdm.tqdm(clips):\n",
        "    # Input shape: [1, 1, 172, 172, 3]\n",
        "    logits, states = model.predict({**states, 'image': clip}, verbose=0)\n",
        "    all_logits.append(logits)\n",
        "\n",
        "  logits = tf.concat(all_logits, 0)\n",
        "  probs = tf.nn.softmax(logits)\n",
        "\n",
        "  print('Generating the plot...')\n",
        "\n",
        "  # Generate a plot and output to a video tensor\n",
        "  plot_video = plot_streaming_top_preds(probs, video, video_fps=video_fps)\n",
        "  media.show_video(plot_video, fps=display_fps, codec='gif')\n",
        "\n",
        "model_size = 'm' #@param [\"xs\", \"s\", \"m\", \"l\", \"xl\", \"xxl\"]\n",
        "\n",
        "model_map = {\n",
        "    'xs': 'a0',\n",
        "    's': 'a1',\n",
        "    'm': 'a2',\n",
        "    'l': 'a3',\n",
        "    'xl': 'a4',\n",
        "    'xxl': 'a5',\n",
        "}\n",
        "movinet_model_id = model_map[model_size]\n",
        "\n",
        "model = load_movinet_from_hub(\n",
        "    movinet_model_id, 'stream', hub_version=3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "jO6HrPk8pqo8"
      },
      "outputs": [],
      "source": [
        "#@title Generate a video plot.\n",
        "\n",
        "#@markdown You may add a video URL (gif or mp4) or leave the video_url field blank to upload your own file.\n",
        "video_url = \"https://i.pinimg.com/originals/33/5e/31/335e31bc8ed52511da0cfb4bc44e95c7.gif\"  #@param  {type:\"string\"}\n",
        "\n",
        "#@markdown The base input resolution to the model. A good value is 224, but can change based on model size.\n",
        "resolution = 224 #@param\n",
        "#@markdown The fps of the input video.\n",
        "video_fps = 12  #@param\n",
        "#@markdown The fps to display the output plot. Depending on the duration of the input video, it may help to use a lower fps.\n",
        "display_fps = 12  #@param\n",
        "\n",
        "generate_plot(\n",
        "    model,\n",
        "    video_url=video_url,\n",
        "    resolution=resolution,\n",
        "    video_fps=video_fps,\n",
        "    display_fps=display_fps)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3",
        "kind": "private"
      },
      "name": "plot_movinet_video_stream_predictions.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
